{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 泰坦尼克号生存预测——模型选择与训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LIJIANcoder97"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 上一次作业已经将训练集的数据清理完毕了本节将继续清理测试集的数据并使用 KNN，逻辑回归，支持向量机模型进行预测，本文主要研究KNN预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 调用sklearn LabelEncoder()，进行编码，文本转数字。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_features(data_set, feature_names):\n",
    "    for feature_name in feature_names:\n",
    "        le = LabelEncoder()\n",
    "        le.fit(data_set[feature_name])\n",
    "        encoded_column = le.transform(data_set[feature_name])\n",
    "        data_set[feature_name] = encoded_column\n",
    "    return data_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_train=pd.read_csv('dataset/titanic/train.csv')\n",
    "Taitan_test=pd.read_csv('dataset/titanic/test.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "丢弃'Cabin','Name','Ticket'属性列"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_train=Taitan_train.drop(['Cabin','Name','Ticket'],axis=1)\n",
    "Taitan_test=Taitan_test.drop(['Cabin','Name','Ticket'],axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 对sex embark属性进行编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "le = LabelEncoder()\n",
    "le.fit(Taitan_train['Sex'])\n",
    "Taitan_train['Sex']=le.transform(Taitan_train['Sex']) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "le = LabelEncoder()\n",
    "le.fit(Taitan_test['Sex'])\n",
    "Taitan_test['Sex']=le.transform(Taitan_test['Sex']) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Embarked 属性有缺失值，要先进行处理，缺失值替换为最多的登船口"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_train['Embarked'] = Taitan_train['Embarked'].fillna('S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['C', 'Q', 'S'], dtype=object)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "l = LabelEncoder()\n",
    "l.fit(Taitan_train['Embarked'])\n",
    "l.classes_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_train['Embarked']=l.transform(Taitan_train['Embarked']) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_test['Embarked'] = Taitan_test['Embarked'].fillna('S')\n",
    "l = LabelEncoder()\n",
    "l.fit(Taitan_test['Embarked'])\n",
    "Taitan_test['Embarked']=l.transform(Taitan_test['Embarked']) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 编码完毕\n",
    "### 对Age 属性进行处理，缺失值替换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_train['Age'] = Taitan_train['Age'].fillna(Taitan_train['Age'].median())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_test['Age'] = Taitan_test['Age'].fillna(Taitan_test['Age'].median())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 创建family属性代表乘客的亲人数目探索其与存活率的关系"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_train['Family']=Taitan_train['SibSp']+Taitan_train['Parch']+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_test['Family']=Taitan_test['SibSp']+Taitan_test['Parch']+1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在已经有了清理好的数据集了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Embarked</th>\n",
       "      <th>Family</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Survived  Pclass  Sex   Age  SibSp  Parch     Fare  Embarked  \\\n",
       "0            1         0       3    1  22.0      1      0   7.2500         2   \n",
       "1            2         1       1    0  38.0      1      0  71.2833         0   \n",
       "2            3         1       3    0  26.0      0      0   7.9250         2   \n",
       "3            4         1       1    0  35.0      1      0  53.1000         2   \n",
       "4            5         0       3    1  35.0      0      0   8.0500         2   \n",
       "\n",
       "   Family  \n",
       "0       2  \n",
       "1       2  \n",
       "2       1  \n",
       "3       2  \n",
       "4       1  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Taitan_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Embarked</th>\n",
       "      <th>Family</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>892</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>34.5</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.8292</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>893</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>47.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.0000</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>894</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>62.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>9.6875</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>895</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>27.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.6625</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>896</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>12.2875</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Pclass  Sex   Age  SibSp  Parch     Fare  Embarked  Family\n",
       "0          892       3    1  34.5      0      0   7.8292         1       1\n",
       "1          893       3    0  47.0      1      0   7.0000         2       2\n",
       "2          894       2    1  62.0      0      0   9.6875         1       1\n",
       "3          895       3    1  27.0      0      0   8.6625         2       1\n",
       "4          896       3    0  22.0      1      1  12.2875         2       3"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Taitan_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_p=Taitan_train['Survived']\n",
    "Taitan_train=Taitan_train.drop(['Survived'],axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型训练\n",
    "## KNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入knn和准确率评估方法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先通过网格搜索选出最合适的超参数，将训练集分为两部分一部分用于训练，一部分用于验证。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train=Taitan_train[0:712]\n",
    "x_test=Taitan_train[712: ]\n",
    "y_train=Taitan_p[0:712]\n",
    "y_test=Taitan_p[712: ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练集分割完毕开始训练,使用网格搜索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k= 2 ;p= 1 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.70      0.91      0.79       115\n",
      "           1       0.64      0.28      0.39        64\n",
      "\n",
      "   micro avg       0.69      0.69      0.69       179\n",
      "   macro avg       0.67      0.60      0.59       179\n",
      "weighted avg       0.68      0.69      0.65       179\n",
      "\n",
      "k= 2 ;p= 1 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.73      0.72      0.73       115\n",
      "           1       0.52      0.53      0.52        64\n",
      "\n",
      "   micro avg       0.65      0.65      0.65       179\n",
      "   macro avg       0.62      0.63      0.63       179\n",
      "weighted avg       0.66      0.65      0.65       179\n",
      "\n",
      "k= 2 ;p= 2 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.65      0.95      0.77       115\n",
      "           1       0.50      0.09      0.16        64\n",
      "\n",
      "   micro avg       0.64      0.64      0.64       179\n",
      "   macro avg       0.58      0.52      0.47       179\n",
      "weighted avg       0.60      0.64      0.55       179\n",
      "\n",
      "k= 2 ;p= 2 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.66      0.57      0.61       115\n",
      "           1       0.38      0.47      0.42        64\n",
      "\n",
      "   micro avg       0.54      0.54      0.54       179\n",
      "   macro avg       0.52      0.52      0.52       179\n",
      "weighted avg       0.56      0.54      0.54       179\n",
      "\n",
      "k= 3 ;p= 1 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.77      0.80      0.79       115\n",
      "           1       0.62      0.58      0.60        64\n",
      "\n",
      "   micro avg       0.72      0.72      0.72       179\n",
      "   macro avg       0.69      0.69      0.69       179\n",
      "weighted avg       0.72      0.72      0.72       179\n",
      "\n",
      "k= 3 ;p= 1 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.77      0.80      0.79       115\n",
      "           1       0.62      0.58      0.60        64\n",
      "\n",
      "   micro avg       0.72      0.72      0.72       179\n",
      "   macro avg       0.69      0.69      0.69       179\n",
      "weighted avg       0.72      0.72      0.72       179\n",
      "\n",
      "k= 3 ;p= 2 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.70      0.30      0.42       115\n",
      "           1       0.38      0.77      0.51        64\n",
      "\n",
      "   micro avg       0.47      0.47      0.47       179\n",
      "   macro avg       0.54      0.53      0.47       179\n",
      "weighted avg       0.59      0.47      0.45       179\n",
      "\n",
      "k= 3 ;p= 2 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.70      0.30      0.42       115\n",
      "           1       0.38      0.77      0.51        64\n",
      "\n",
      "   micro avg       0.47      0.47      0.47       179\n",
      "   macro avg       0.54      0.53      0.47       179\n",
      "weighted avg       0.59      0.47      0.45       179\n",
      "\n",
      "k= 4 ;p= 1 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.71      0.90      0.80       115\n",
      "           1       0.67      0.34      0.45        64\n",
      "\n",
      "   micro avg       0.70      0.70      0.70       179\n",
      "   macro avg       0.69      0.62      0.63       179\n",
      "weighted avg       0.70      0.70      0.67       179\n",
      "\n",
      "k= 4 ;p= 1 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.76      0.82      0.79       115\n",
      "           1       0.62      0.53      0.57        64\n",
      "\n",
      "   micro avg       0.72      0.72      0.72       179\n",
      "   macro avg       0.69      0.67      0.68       179\n",
      "weighted avg       0.71      0.72      0.71       179\n",
      "\n",
      "k= 4 ;p= 2 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.68      0.46      0.55       115\n",
      "           1       0.39      0.61      0.47        64\n",
      "\n",
      "   micro avg       0.51      0.51      0.51       179\n",
      "   macro avg       0.53      0.54      0.51       179\n",
      "weighted avg       0.57      0.51      0.52       179\n",
      "\n",
      "k= 4 ;p= 2 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.71      0.30      0.43       115\n",
      "           1       0.38      0.78      0.52        64\n",
      "\n",
      "   micro avg       0.47      0.47      0.47       179\n",
      "   macro avg       0.55      0.54      0.47       179\n",
      "weighted avg       0.60      0.47      0.46       179\n",
      "\n",
      "k= 5 ;p= 1 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.75      0.80      0.77       115\n",
      "           1       0.59      0.52      0.55        64\n",
      "\n",
      "   micro avg       0.70      0.70      0.70       179\n",
      "   macro avg       0.67      0.66      0.66       179\n",
      "weighted avg       0.69      0.70      0.69       179\n",
      "\n",
      "k= 5 ;p= 1 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.75      0.80      0.77       115\n",
      "           1       0.59      0.52      0.55        64\n",
      "\n",
      "   micro avg       0.70      0.70      0.70       179\n",
      "   macro avg       0.67      0.66      0.66       179\n",
      "weighted avg       0.69      0.70      0.69       179\n",
      "\n",
      "k= 5 ;p= 2 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.64      0.26      0.37       115\n",
      "           1       0.36      0.73      0.48        64\n",
      "\n",
      "   micro avg       0.43      0.43      0.43       179\n",
      "   macro avg       0.50      0.50      0.42       179\n",
      "weighted avg       0.54      0.43      0.41       179\n",
      "\n",
      "k= 5 ;p= 2 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.64      0.26      0.37       115\n",
      "           1       0.36      0.73      0.48        64\n",
      "\n",
      "   micro avg       0.43      0.43      0.43       179\n",
      "   macro avg       0.50      0.50      0.42       179\n",
      "weighted avg       0.54      0.43      0.41       179\n",
      "\n",
      "k= 6 ;p= 1 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.71      0.86      0.78       115\n",
      "           1       0.60      0.38      0.46        64\n",
      "\n",
      "   micro avg       0.69      0.69      0.69       179\n",
      "   macro avg       0.66      0.62      0.62       179\n",
      "weighted avg       0.67      0.69      0.67       179\n",
      "\n",
      "k= 6 ;p= 1 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.74      0.83      0.78       115\n",
      "           1       0.61      0.48      0.54        64\n",
      "\n",
      "   micro avg       0.70      0.70      0.70       179\n",
      "   macro avg       0.68      0.66      0.66       179\n",
      "weighted avg       0.69      0.70      0.70       179\n",
      "\n",
      "k= 6 ;p= 2 ;w= uniform ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.72      0.63      0.67       115\n",
      "           1       0.46      0.56      0.50        64\n",
      "\n",
      "   micro avg       0.60      0.60      0.60       179\n",
      "   macro avg       0.59      0.59      0.59       179\n",
      "weighted avg       0.63      0.60      0.61       179\n",
      "\n",
      "k= 6 ;p= 2 ;w= distance ;结果:               precision    recall  f1-score   support\n",
      "\n",
      "           0       0.69      0.25      0.37       115\n",
      "           1       0.37      0.80      0.51        64\n",
      "\n",
      "   micro avg       0.45      0.45      0.45       179\n",
      "   macro avg       0.53      0.52      0.44       179\n",
      "weighted avg       0.58      0.45      0.42       179\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for k in [2,3,4,5,6]:\n",
    "    for p in [1,2]: ##1代表明可夫斯基距离，2代表欧拉距离\n",
    "        for w in ['uniform','distance']: ##uniform 代表不考虑距离权重\n",
    "            n= KNeighborsClassifier(n_neighbors=k,weights=w,p=p)\n",
    "            n.fit(x_train,y_train)##构造模型！！\n",
    "            y_p=n.predict(x_test)##预测\n",
    "            print('k=',k,';p=',p,';w=',w,';结果:',classification_report(y_test,y_p))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "准确率的参考属性时micro"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "准确率最高为 0.72，分别是 3，1，uniform；3，1，distance；4，1，distance\n",
    "选择4，1，distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Input contains NaN, infinity or a value too large for dtype('float64').",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-33-507fa572df29>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[0mn\u001b[0m\u001b[1;33m=\u001b[0m \u001b[0mKNeighborsClassifier\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn_neighbors\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mweights\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'distance'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[0mn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mTaitan_train\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mTaitan_p\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;31m##构造模型！！\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0my_re\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mTaitan_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;31m##预测\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\sklearn\\neighbors\\classification.py\u001b[0m in \u001b[0;36mpredict\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m    145\u001b[0m             \u001b[0mClass\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0meach\u001b[0m \u001b[0mdata\u001b[0m \u001b[0msample\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    146\u001b[0m         \"\"\"\n\u001b[1;32m--> 147\u001b[1;33m         \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'csr'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    148\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    149\u001b[0m         \u001b[0mneigh_dist\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mneigh_ind\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mkneighbors\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py\u001b[0m in \u001b[0;36mcheck_array\u001b[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)\u001b[0m\n\u001b[0;32m    571\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mforce_all_finite\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    572\u001b[0m             _assert_all_finite(array,\n\u001b[1;32m--> 573\u001b[1;33m                                allow_nan=force_all_finite == 'allow-nan')\n\u001b[0m\u001b[0;32m    574\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    575\u001b[0m     \u001b[0mshape_repr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_shape_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py\u001b[0m in \u001b[0;36m_assert_all_finite\u001b[1;34m(X, allow_nan)\u001b[0m\n\u001b[0;32m     54\u001b[0m                 not allow_nan and not np.isfinite(X).all()):\n\u001b[0;32m     55\u001b[0m             \u001b[0mtype_err\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'infinity'\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mallow_nan\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34m'NaN, infinity'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 56\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmsg_err\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtype_err\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     57\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     58\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mValueError\u001b[0m: Input contains NaN, infinity or a value too large for dtype('float64')."
     ]
    }
   ],
   "source": [
    "n= KNeighborsClassifier(n_neighbors=4,weights='distance',p=1)\n",
    "n.fit(Taitan_train,Taitan_p)##构造模型！！\n",
    "y_re=n.predict(Taitan_test)##预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PassengerId    0\n",
       "Pclass         0\n",
       "Sex            0\n",
       "Age            0\n",
       "SibSp          0\n",
       "Parch          0\n",
       "Fare           1\n",
       "Embarked       0\n",
       "Family         0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Taitan_test.isnull().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 测试集竟然由空值，一定要注意"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_test['Fare'] = Taitan_train['Fare'].fillna(Taitan_train['Fare'].median())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PassengerId    0\n",
       "Pclass         0\n",
       "Sex            0\n",
       "Age            0\n",
       "SibSp          0\n",
       "Parch          0\n",
       "Fare           0\n",
       "Embarked       0\n",
       "Family         0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Taitan_test.isnull().sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "清理完毕再次测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "n= KNeighborsClassifier(n_neighbors=4,weights='distance',p=1)\n",
    "n.fit(Taitan_train,Taitan_p)##构造模型！！\n",
    "y_re=n.predict(Taitan_test)##预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,\n",
       "       0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "       0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0,\n",
       "       0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0,\n",
       "       1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,\n",
       "       1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0,\n",
       "       0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,\n",
       "       1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,\n",
       "       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,\n",
       "       0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,\n",
       "       0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "       0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,\n",
       "       0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1,\n",
       "       0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1,\n",
       "       0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n",
       "       0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1,\n",
       "       0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],\n",
       "      dtype=int64)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_test['Survived']=y_re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "Taitan_re=Taitan_test[['PassengerId','Survived']]"
   ]
  },
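  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# index=False keeps the row index out of the file, leaving only PassengerId and Survived\n",
    "Taitan_re.to_csv('dataset/titanic/RE.csv', index=False)"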
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>892</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>893</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>894</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>895</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>896</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Survived\n",
       "0          892         0\n",
       "1          893         1\n",
       "2          894         0\n",
       "3          895         1\n",
       "4          896         0"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Taitan_re.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>418.000000</td>\n",
       "      <td>418.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>1100.500000</td>\n",
       "      <td>0.351675</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>120.810458</td>\n",
       "      <td>0.478065</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>892.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>996.250000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>1100.500000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>1204.750000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>1309.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       PassengerId    Survived\n",
       "count   418.000000  418.000000\n",
       "mean   1100.500000    0.351675\n",
       "std     120.810458    0.478065\n",
       "min     892.000000    0.000000\n",
       "25%     996.250000    0.000000\n",
       "50%    1100.500000    0.000000\n",
       "75%    1204.750000    1.000000\n",
       "max    1309.000000    1.000000"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Taitan_re.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "KNN training and prediction are now complete. Ideally, the accuracy of the predictions on the test set would also reach 0.72.\n"
   ]
  },
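  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above, we used grid search on the training set to find the most suitable hyperparameters. The output shows the accuracy varying between 0.43 and 0.72, so the hyperparameters have a large impact on the training result. We tried 20 hyperparameter combinations in total, and k and p can take many more values. Tuning them by hand would waste a great deal of time, whereas grid search saves it."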
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
